import torch
import matplotlib.pyplot as plt

# -----------------散点图------------------
x = torch.linspace(0, 1, 20)
y = 3 * x + 2
# 加入散度
y += torch.normal(0, 0.2, (20,))

plt.plot(x, y, 'ro')

# -----------------预测线--------------------
w_predict = 0.1  # 预测的w
b_predict = 0.2  # 预测的b

y_predict = w_predict * x + b_predict
plt.plot(x, y_predict, 'b--')

# ------------------衡量预测线的准确率---------
e = (y_predict - y) ** 2  # 均方差 MSE
print(e)  # 该值越小越少
print(torch.mean(e)) # 将所有的均方差进行汇总,求平均值
print(torch.sum(e)) # 将所有的均方差进行汇总,求和

plt.show()
